import pytest
from mlagents.torch_utils import torch

from mlagents.trainers.torch_entities.distributions import (
    GaussianDistribution,
    MultiCategoricalDistribution,
    GaussianDistInstance,
    TanhGaussianDistInstance,
    CategoricalDistInstance,
)


@pytest.mark.parametrize("tanh_squash", [True, False])
@pytest.mark.parametrize("conditional_sigma", [True, False])
def test_gaussian_distribution(conditional_sigma, tanh_squash):
    """Check GaussianDistribution's instance type and that it can be trained.

    For every combination of ``conditional_sigma`` and ``tanh_squash`` the
    distribution head is optimized for 50 Adam steps so that the log
    probability of an all-zero action approaches -2. This exercises both the
    forward pass and backpropagation through ``log_prob``. The log
    probabilities from the final iteration must be within 0.1 of -2.
    """
    torch.manual_seed(0)
    hidden_size = 16
    act_size = 4
    # Derive the embedding width from hidden_size so the two cannot drift apart.
    sample_embedding = torch.ones((1, hidden_size))
    gauss_dist = GaussianDistribution(
        hidden_size,
        act_size,
        conditional_sigma=conditional_sigma,
        tanh_squash=tanh_squash,
    )

    # Make sure backprop works
    force_action = torch.zeros((1, act_size))
    optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)
    # The concrete instance type depends only on tanh_squash.
    expected_cls = TanhGaussianDistInstance if tanh_squash else GaussianDistInstance

    for _ in range(50):
        dist_inst = gauss_dist(sample_embedding)
        assert isinstance(dist_inst, expected_cls)
        log_prob = dist_inst.log_prob(force_action)
        # full_like keeps the target on the same device/dtype as log_prob.
        target = torch.full_like(log_prob, -2.0)
        loss = torch.nn.functional.mse_loss(log_prob, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # log_prob is from the last iteration of the training loop above.
    for prob in log_prob.flatten().tolist():
        assert prob == pytest.approx(-2, abs=0.1)


def test_multi_categorical_distribution():
    """Check that MultiCategoricalDistribution trains per branch and honors masks.

    Each action branch is optimized for 100 Adam steps toward a fixed target
    distribution that strongly favours its first action. The resulting log
    probabilities must be within 0.1 of the targets. Afterwards, a mask that
    leaves only the last action of every branch available must make that
    action's log probability (almost) 0.
    """
    torch.manual_seed(0)
    hidden_size = 16
    act_size = [3, 3, 4]
    # Derive the embedding width from hidden_size so the two cannot drift apart.
    sample_embedding = torch.ones((1, hidden_size))
    cat_dist = MultiCategoricalDistribution(hidden_size, act_size)

    # Make sure backprop works
    optimizer = torch.optim.Adam(cat_dist.parameters(), lr=3e-3)

    def create_test_prob(size: int) -> torch.Tensor:
        """Return target log probabilities of shape (1, size) favouring action 0."""
        test_prob = torch.tensor(
            [[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)]
        )  # High prob for first action
        return test_prob.log()

    # An all-ones mask leaves every action of every branch available.
    no_masks = torch.ones((1, sum(act_size)))
    for _ in range(100):
        dist_insts = cat_dist(sample_embedding, masks=no_masks)
        # One distribution instance is expected per action branch.
        assert len(dist_insts) == len(act_size)
        loss = 0.0
        for dist_inst, size in zip(dist_insts, act_size):
            assert isinstance(dist_inst, CategoricalDistInstance)
            # Force log_probs to match the high probability for the first
            # action generated by create_test_prob
            loss += torch.nn.functional.mse_loss(
                dist_inst.all_log_prob(), create_test_prob(size)
            )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    for dist_inst, size in zip(dist_insts, act_size):
        # Check that the log probs are close to the fake ones that we generated.
        test_log_probs = create_test_prob(size)
        for _prob, _test_prob in zip(
            dist_inst.all_log_prob().flatten().tolist(),
            test_log_probs.flatten().tolist(),
        ):
            assert _prob == pytest.approx(_test_prob, abs=0.1)

    # Test masks: only the last action of each branch is left available.
    masks = []
    for branch in act_size:
        masks += [0] * (branch - 1) + [1]
    # Use a float mask, matching the all-ones mask used during training.
    masks = torch.tensor([masks], dtype=torch.float32)
    dist_insts = cat_dist(sample_embedding, masks=masks)
    for dist_inst in dist_insts:
        log_prob = dist_inst.all_log_prob()
        # log(1) == 0: the sole unmasked action should carry all the mass.
        assert log_prob.flatten()[-1].item() == pytest.approx(0, abs=0.001)


def test_gaussian_dist_instance():
    """Check sample shape, log prob and entropy of a standard-normal instance.

    The instance is built from an all-zero mean row and an all-one row (unit
    sigma), so every action dimension behaves like N(0, 1).
    """
    torch.manual_seed(0)
    act_size = 4
    # -ln(sqrt(2*pi)): log density of the standard normal evaluated at 0.
    expected_log_prob = -0.919
    # 1/2 + ln(sqrt(2*pi) * sigma) with sigma == 1.
    expected_entropy = 1.42

    zero_mean = torch.zeros(1, act_size)
    unit_sigma = torch.ones(1, act_size)
    dist_instance = GaussianDistInstance(zero_mean, unit_sigma)

    drawn = dist_instance.sample()
    assert drawn.shape == (1, act_size)

    at_origin = torch.zeros((1, act_size))
    per_dim_log_probs = dist_instance.log_prob(at_origin).flatten().tolist()
    for value in per_dim_log_probs:
        assert value == pytest.approx(expected_log_prob, abs=0.01)

    per_dim_entropy = dist_instance.entropy().flatten().tolist()
    for value in per_dim_entropy:
        assert value == pytest.approx(expected_entropy, abs=0.01)


def test_tanh_gaussian_dist_instance():
    """Samples from a TanhGaussianDistInstance must lie strictly inside (-1, 1).

    Ten draws are taken; each must have shape (1, act_size).
    """
    torch.manual_seed(0)
    act_size = 4
    center = torch.zeros(1, act_size)
    spread = torch.ones(1, act_size)
    dist_instance = TanhGaussianDistInstance(center, spread)

    num_draws = 10
    for _ in range(num_draws):
        squashed = dist_instance.sample()
        assert squashed.shape == (1, act_size)
        upper_ok = torch.max(squashed) < 1.0
        assert upper_ok and torch.min(squashed) > -1.0


def test_categorical_dist_instance():
    """Check sampling and log-prob ordering of a CategoricalDistInstance.

    The instance is built from a single row whose first entry is the largest,
    so sampled actions must be valid indices. The log probability of action 0
    must also exceed that of every other action.
    """
    torch.manual_seed(0)
    act_size = 4
    test_prob = torch.tensor(
        [[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)]
    )  # High prob for first action
    # NOTE(review): whether this row is treated as logits or probabilities is
    # decided by CategoricalDistInstance; either way action 0 dominates.
    dist_instance = CategoricalDistInstance(test_prob)

    for _ in range(10):
        action = dist_instance.sample()
        assert action.shape == (1, 1)
        # A sampled index must be a valid action: 0 <= action < act_size.
        assert 0 <= action.item() < act_size

    # Make sure the first action has higher probability than the others.
    log_prob_first_action = dist_instance.log_prob(torch.tensor([0]))

    for i in range(1, act_size):
        assert dist_instance.log_prob(torch.tensor([i])) < log_prob_first_action
